# -*- coding: utf-8 -*-

# 通过对 torch 的 Tensor 对象使用切片操作可以提取感兴趣的数据。

from torch import Tensor

batch_data = Tensor(
    [
        [
            [1, 2],
            [3, 4],
            [5, 6]     # 只要这个
        ],
        [
            [7, 8],
            [9, 10],
            [11, 12]     # 只要这个
        ]
    ]
)
print('batch_data:', batch_data)

# 演示只要 [5, 6] 和 [11, 12] 时可以怎么做：

extracted = batch_data[:, -1:]
# 说明：[:, -1:]，对列表中的每一个元素（这些元素类型都是列表）都保留最后一个元素，也就是保留[5,6]和[11,12]
print('extracted:', extracted)

# 执行代码输出：
# batch_data: tensor([[[ 1.,  2.],
#          [ 3.,  4.],
#          [ 5.,  6.]],
#
#         [[ 7.,  8.],
#          [ 9., 10.],
#          [11., 12.]]])
# extracted: tensor([[[ 5.,  6.]],
#
#         [[11., 12.]]])
